import os
import sys
import gym
import numpy as np
import d4rl
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.stats import pearsonr
from typing import Dict, Tuple, Optional

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import reward_design.utils


class PreferenceModel:
    """Score candidate dense-reward functions against an offline D4RL dataset.

    On construction the dataset for ``env_name`` is loaded, split into
    trajectories and ranked (see :meth:`get_sorted_trajs`).  The top
    ``num_top_episodes`` ranked trajectories are the "experts", the remaining
    ranked trajectories are "non-experts", and noisy copies of the
    lowest-ranked expert form a synthetic comparison set.
    :meth:`compute_scores` measures how well a candidate reward function
    ranks the experts above both groups.
    """

    def __init__(
        self,
        env_name,
        design_mode,
        model_seed,
        num_top_episodes,
        expert_tolerance,
        noise_tolerance,
        obs_noise_scale,
        action_noise_scale,
        synthetic_trajs_num=10000,
        use_return_per_step=False,
        fix_antmaze_timeout=True,
        disable_goal=False,
        fix_goal=True,
    ):
        """Create the environment, seed everything and build the evaluation sets.

        Args:
            env_name: Environment id passed to ``gym.make``.
            design_mode: Signature of the candidate reward function:
                ``"sa"`` -> ``r(obs, action)``, ``"sas"`` -> ``r(obs, action, next_obs)``,
                ``"ss"`` -> ``r(obs, next_obs)``.
            model_seed: Seed for the env, its action space and
                ``reward_design.utils.set_seed_everywhere``.
            num_top_episodes: Number of highest-ranked trajectories treated as experts.
            expert_tolerance: Relative slack on the expert threshold used for
                ``score_expert`` in :meth:`compute_scores`.
            noise_tolerance: Relative slack on the expert threshold used for
                ``score_noisy`` in :meth:`compute_scores`.
            obs_noise_scale: Multiplier on the per-dimension std of observations
                (and next observations) giving the Gaussian noise std.
            action_noise_scale: Multiplier on the per-dimension std of actions.
            synthetic_trajs_num: Total size of the expert + noisy evaluation set.
            use_return_per_step: Rank trajectories by return divided by length.
            fix_antmaze_timeout: For antmaze envs, load data with
                :meth:`qlearning_dataset_with_timeouts` instead of
                ``d4rl.qlearning_dataset``.
            disable_goal: Do not append goal coordinates to observations.
            fix_goal: Append goals from ``reward_design.utils.get_antmaze_fix_goals``
                instead of the dataset's ``infos/goal``.
        """
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.design_mode = design_mode
        self.num_top_episodes = num_top_episodes
        self.expert_tolerance = expert_tolerance
        self.noise_tolerance = noise_tolerance
        self.obs_noise_scale = obs_noise_scale
        self.action_noise_scale = action_noise_scale
        self.use_return_per_step = use_return_per_step
        self.fix_antmaze_timeout = fix_antmaze_timeout
        self.disable_goal = disable_goal
        self.synthetic_trajs_num = synthetic_trajs_num
        self.fix_goal = fix_goal

        # Seed the env, its action space and the global RNGs.
        self.env.seed(model_seed)
        self.env.action_space.seed(model_seed)
        reward_design.utils.set_seed_everywhere(seed=model_seed, using_cuda=True)

        # Load the dataset, then rank its trajectories (best first) and build
        # the expert + synthetic-noise evaluation set.
        self.trajs = self.load_trajectories()
        (
            self.trajs_all,
            self.sorted_indices,
            self.sorted_trajs,
            self.sorted_returns,
            self.sorted_lens,
        ) = self.get_sorted_trajs()

    def compute_scores(self, reward_code: str):
        """Score a candidate reward function by how well it ranks experts first.

        ``reward_code`` must define ``compute_dense_reward`` with the signature
        selected by ``self.design_mode``.  Per-step rewards are summed into a
        "calculated return" for every trajectory, and two sub-scores are formed:

        * ``score_expert``: fraction of the non-expert dataset trajectories
          whose calculated return is ``<=`` the tolerance-adjusted minimum
          expert return (``expert_tolerance``).
        * ``score_noisy``: fraction of the synthetic noisy trajectories whose
          calculated return is strictly ``<`` the tolerance-adjusted minimum
          expert return (``noise_tolerance``).

        Args:
            reward_code: Python source defining ``compute_dense_reward``.

        Returns:
            ``(score, info)`` where ``score`` is the mean of the two sub-scores
            and ``info`` also contains the step-level Pearson correlation
            (``step_corr``, ``step_p``) between environment and calculated
            rewards and the max/min calculated return over the ranked dataset.

        Raises:
            NotImplementedError: If ``self.design_mode`` is not ``"sa"``,
                ``"sas"`` or ``"ss"``.
            ValueError: If there are no expert trajectories to threshold against.
        """
        # SECURITY: ``exec`` runs arbitrary code with this module's globals;
        # only pass reward code that comes from a trusted source.
        namespace = {**globals(), **{"Dict": Dict, "Tuple": Tuple, "Optional": Optional}}
        exec(reward_code, namespace)
        evaluate = self._make_reward_evaluator(namespace)

        # Pass 1: ranked dataset trajectories (experts first, then non-experts).
        expert_returns, non_expert_returns = [], []
        step_env_rewards, step_cal_rewards = [], []
        cal_returns = []
        for rank, traj in enumerate(self.sorted_trajs):
            rewards = self._trajectory_rewards(evaluate, traj)
            step_env_rewards.extend(frame[2] for frame in traj)
            step_cal_rewards.extend(rewards)
            traj_cal_return = sum(rewards)
            cal_returns.append(traj_cal_return)
            if rank < self.num_top_episodes:
                expert_returns.append(traj_cal_return)
            else:
                non_expert_returns.append(traj_cal_return)

        expert_threshold = self._tolerant_threshold(expert_returns, self.expert_tolerance)
        # Inclusive comparison here; the noisy-set comparison below is strict.
        score_expert = self._fraction(
            non_expert_returns, lambda val: expert_threshold >= val
        )

        # How closely the candidate reward tracks the env reward, step by step.
        step_corr, step_p = pearsonr(step_env_rewards, step_cal_rewards)
        print(f"Step-level Pearson correlation coefficient: {step_corr:.3f}, p value: {step_p:.3g}")

        # Pass 2: experts followed by their synthetic noisy copies.
        expert_trajs_returns, noisy_trajs_returns = [], []
        for idx, traj in enumerate(self.trajs_all):
            traj_return = sum(self._trajectory_rewards(evaluate, traj))
            if idx < self.num_top_episodes:
                expert_trajs_returns.append(traj_return)
            else:
                noisy_trajs_returns.append(traj_return)

        noisy_threshold = self._tolerant_threshold(expert_trajs_returns, self.noise_tolerance)
        score_noisy = self._fraction(noisy_trajs_returns, lambda val: noisy_threshold > val)

        score = (score_expert + score_noisy) / 2

        return score, {
            "step_corr": step_corr,
            "step_p": step_p,
            "score": score,
            "score_expert": score_expert,
            "score_noisy": score_noisy,
            "max_return": max(cal_returns),
            "min_return": min(cal_returns),
        }

    def _make_reward_evaluator(self, namespace):
        """Adapt ``namespace["compute_dense_reward"]`` to ``f(obs, action, next_obs)``.

        Raises:
            NotImplementedError: If ``self.design_mode`` is not supported.
        """
        mode = self.design_mode
        if mode not in ("sa", "sas", "ss"):
            raise NotImplementedError(f"Unsupported design_mode: {mode!r}")
        reward_fn = namespace["compute_dense_reward"]
        if mode == "sa":
            return lambda obs, action, next_obs: reward_fn(obs, action)
        if mode == "sas":
            return lambda obs, action, next_obs: reward_fn(obs, action, next_obs)
        return lambda obs, action, next_obs: reward_fn(obs, next_obs)

    @staticmethod
    def _trajectory_rewards(evaluate, traj):
        """Evaluate the candidate reward on every frame of ``traj``, in order."""
        return [
            evaluate(obs, action, next_obs)
            for obs, action, _reward, _mask, _done, next_obs in traj
        ]

    @staticmethod
    def _tolerant_threshold(expert_returns, tolerance):
        """Return ``min(expert_returns)`` raised by ``tolerance * |min|``.

        Raises:
            ValueError: If ``expert_returns`` is empty.
        """
        if not expert_returns:
            raise ValueError(
                "No expert trajectories to threshold against; "
                "num_top_episodes must be >= 1."
            )
        threshold = min(expert_returns)
        # Multiplying by (1 + tol) for non-negative values and (1 - tol) for
        # negative values moves the threshold upward in both cases.
        if threshold >= 0:
            return (1 + tolerance) * threshold
        return (1 - tolerance) * threshold

    @staticmethod
    def _fraction(values, predicate):
        """Return the fraction of ``values`` satisfying ``predicate`` (0.0 if empty)."""
        if not values:
            return 0.0
        return sum(1 for val in values if predicate(val)) / len(values)

    def load_trajectories(self):
        """Load the D4RL dataset for ``self.env`` and split it into trajectories.

        Antmaze envs (when ``fix_antmaze_timeout`` is set) are loaded with
        :meth:`qlearning_dataset_with_timeouts`; all others with
        ``d4rl.qlearning_dataset``.  A trajectory ends after transition ``i``
        when ``observations[i + 1]`` differs from ``next_observations[i]``
        (L2 distance > 1e-6) or ``terminals[i]`` is set.  The final transition
        always closes a trajectory.

        Returns:
            List of trajectories, each a list of
            ``(obs, action, reward, mask, done_float, next_obs)`` tuples.
        """
        if "antmaze" in self.env_name and self.fix_antmaze_timeout:
            print(
                "=" * 10
                + f"Env name={self.env_name}. Use qlearning_dataset_with_timeouts."
                + "=" * 10
            )
            # The method reads ``self.env`` itself.  Passing the env positionally
            # would bind it to ``terminate_on_end`` (truthy) and keep the
            # episode-boundary transitions that should be skipped.
            dataset = self.qlearning_dataset_with_timeouts()
        else:
            print(
                "=" * 10
                + f"Env name={self.env_name}. Use d4rl.qlearning_dataset."
                + "=" * 10
            )
            dataset = d4rl.qlearning_dataset(self.env)

        observations = dataset["observations"]
        next_observations = dataset["next_observations"]
        dones_float = np.zeros_like(dataset["rewards"])
        if len(dones_float) > 1:
            # Vectorized: boundary after i when obs[i + 1] does not continue
            # transition i, or when the dataset marks i as terminal.
            jumps = observations[1:] - next_observations[:-1]
            gaps = np.linalg.norm(jumps.reshape(len(jumps), -1), axis=1) > 1e-6
            dones_float[:-1] = gaps | (np.asarray(dataset["terminals"][:-1]) == 1.0)
        if len(dones_float) > 0:
            dones_float[-1] = 1

        if "realterminals" in dataset:
            # We updated terminals in the dataset, but continue using
            # the old terminals for consistency with original IQL.
            masks = 1.0 - dataset["realterminals"].astype(np.float32)
        else:
            masks = 1.0 - dataset["terminals"].astype(np.float32)
        return self.split_into_trajectories(
            observations=observations.astype(np.float32),
            actions=dataset["actions"].astype(np.float32),
            rewards=dataset["rewards"].astype(np.float32),
            masks=masks,
            dones_float=dones_float.astype(np.float32),
            next_observations=next_observations.astype(np.float32),
        )

    def qlearning_dataset_with_timeouts(
        self, terminate_on_end=False, **kwargs
    ):
        """Build a transition dataset from ``self.env.get_dataset()``.

        Transition ``i`` pairs ``observations[i]`` with ``observations[i + 1]``.
        Step ``i`` is an episode end when the goal changes between ``i`` and
        ``i + 1`` (if the dataset has ``infos/goal``), or otherwise when
        ``timeouts[i]`` is set.  Such transitions are dropped unless
        ``terminate_on_end`` is true, since ``observations[i + 1]`` then
        presumably starts the next episode.  When ``infos/goal`` exists and
        ``disable_goal`` is false, goal coordinates are appended to every
        observation: the fixed goals if ``fix_goal`` is set, else
        ``infos/goal``.

        Args:
            terminate_on_end: Keep episode-end transitions, marked as terminals.
            **kwargs: Ignored; accepted for call compatibility.

        Returns:
            Dict with ``observations``, ``actions``, ``next_observations``,
            ``rewards``, ``terminals`` (dataset terminal OR episode end) and
            ``realterminals`` (dataset terminal only) arrays.
        """
        dataset = self.env.get_dataset()

        N = dataset["rewards"].shape[0]
        has_goal = "infos/goal" in dataset
        if has_goal and not self.disable_goal:
            if self.fix_goal:
                goals = reward_design.utils.get_antmaze_fix_goals(
                    self.env_name, dataset["observations"].shape[0]
                )
            else:
                goals = dataset["infos/goal"]
            dataset["observations"] = np.concatenate(
                [dataset["observations"], goals], axis=1
            )

        obs_, next_obs_, action_, reward_, done_, realdone_ = [], [], [], [], [], []
        for i in range(N - 1):
            realdone_bool = bool(dataset["terminals"][i])
            if has_goal:
                final_timestep = bool(
                    (dataset["infos/goal"][i] != dataset["infos/goal"][i + 1]).any()
                )
            else:
                final_timestep = bool(dataset["timeouts"][i])

            if final_timestep and not terminate_on_end:
                # Skip this transition and don't apply terminals on the last
                # step of an episode.
                continue

            obs_.append(dataset["observations"][i])
            next_obs_.append(dataset["observations"][i + 1])
            action_.append(dataset["actions"][i])
            reward_.append(dataset["rewards"][i])
            done_.append(realdone_bool or final_timestep)
            realdone_.append(realdone_bool)

        return {
            "observations": np.array(obs_),
            "actions": np.array(action_),
            "next_observations": np.array(next_obs_),
            "rewards": np.array(reward_),
            "terminals": np.array(done_),
            "realterminals": np.array(realdone_),
        }

    def split_into_trajectories(
        self, observations, actions, rewards, masks, dones_float, next_observations
    ):
        """Group aligned transition arrays into trajectories.

        A new trajectory is started after every index whose ``dones_float``
        equals 1.0, except after the very last transition.

        Returns:
            List of trajectories, each a list of
            ``(obs, action, reward, mask, done_float, next_obs)`` tuples.
            An empty input yields ``[[]]``.
        """
        trajs = [[]]
        total = len(observations)
        frames = zip(observations, actions, rewards, masks, dones_float, next_observations)
        for i, frame in enumerate(frames):
            trajs[-1].append(frame)
            if frame[4] == 1.0 and i + 1 < total:
                trajs.append([])
        return trajs

    def get_sorted_trajs(self):
        """Rank dataset trajectories and build the expert + synthetic-noise set.

        Trajectories are ranked in descending order of their summed environment
        reward.  For antmaze envs that sum is divided by ``1e-4 + ||xy0||``,
        where ``xy0`` is the first two dims of the first observation.  With
        ``use_return_per_step`` the value is further divided by the
        trajectory length.

        Returns:
            ``(trajs_all, sorted_indices, sorted_trajs, sorted_returns, sorted_lens)``.
            ``trajs_all`` holds the top ``num_top_episodes`` trajectories,
            followed by ``synthetic_trajs_num - <number of experts>`` noisy
            copies of the lowest-ranked expert.  An ``IndexError`` is raised
            if there are no experts.
        """
        if "antmaze" in self.env_name:
            returns = [
                sum(t[2] for t in traj) / (1e-4 + np.linalg.norm(traj[0][0][:2]))
                for traj in self.trajs
            ]
        else:
            returns = [sum(t[2] for t in traj) for traj in self.trajs]
        traj_lens = [len(traj) for traj in self.trajs]

        if self.use_return_per_step:
            returns = list(np.array(returns) / np.array(traj_lens))

        sorted_indices = np.argsort(returns)[::-1]
        sorted_trajs = [self.trajs[i] for i in sorted_indices]
        sorted_returns = [returns[i] for i in sorted_indices]
        sorted_lens = [traj_lens[i] for i in sorted_indices]
        expert_trajs = [self.trajs[i] for i in sorted_indices[: self.num_top_episodes]]
        noisy_trajs = self.generate_noisy_trajectories_auto_std(
            traj=expert_trajs[-1],
            m=self.synthetic_trajs_num - len(expert_trajs),
        )
        trajs_all = expert_trajs + noisy_trajs
        return trajs_all, list(sorted_indices), sorted_trajs, sorted_returns, sorted_lens

    def generate_noisy_trajectories_auto_std(self, traj, m):
        """Create ``m`` Gaussian-perturbed copies of ``traj``.

        Each dimension's noise std is that dimension's std over ``traj``,
        scaled by ``obs_noise_scale`` (observations and next observations) or
        ``action_noise_scale`` (actions).  Rewards, masks and done flags are
        copied unchanged.

        For goal-oriented tasks the last frame may carry critical information,
        so it is never perturbed.  It is omitted from every copy.
        NOTE(review): confirm that omitting it, rather than appending it
        un-noised, is the intended behavior.

        Returns:
            List of ``m`` trajectories (empty if ``m <= 0``).  Each is a list of
            ``len(traj) - 1`` frames of the form
            ``[obs, action, reward, mask, done_float, next_obs]``.
        """
        obs_array = np.stack([frame[0] for frame in traj])
        action_array = np.stack([frame[1] for frame in traj])
        next_obs_array = np.stack([frame[5] for frame in traj])

        obs_noise_std = np.std(obs_array, axis=0) * self.obs_noise_scale
        action_noise_std = np.std(action_array, axis=0) * self.action_noise_scale
        # Next observations reuse the observation noise scale.
        next_obs_noise_std = np.std(next_obs_array, axis=0) * self.obs_noise_scale

        noisy_trajs = []
        for _ in range(m):
            noisy_traj = []
            for obs, action, env_reward, mask, done_float, next_obs in traj[:-1]:
                # Draw order (obs, action, next_obs) is fixed so seeded runs reproduce.
                noisy_obs = obs + np.random.normal(0, obs_noise_std)
                noisy_action = action + np.random.normal(0, action_noise_std)
                noisy_next_obs = next_obs + np.random.normal(0, next_obs_noise_std)
                noisy_traj.append([
                    noisy_obs,
                    noisy_action,
                    env_reward,
                    mask,
                    done_float,
                    noisy_next_obs,
                ])
            noisy_trajs.append(noisy_traj)

        return noisy_trajs

